In [1]:
# Define experiment parameters
# `year` selects the survey wave and is interpolated into every data/model/
# result file path below (e.g. data/women_work_data_{year}.csv).
year = "200506"  # "200506" or "201516" (see year-specific branches in post-processing)
target_col = "white_collar"  # 'white_collar', 'blue_collar', 'has_occ'
sample_weight_col = 'women_weight'  # survey sampling weight used for fitting and sub-sampling
In [2]:
# Define resource utilization parameters
random_state = 42  # global seed: numpy, train/test split, CV splitter, weighted sampling
n_jobs_clf = 16    # parallel threads for the LGBM classifier itself
n_jobs_cv = 4      # parallel workers for the (commented-out) grid search
cv_folds = 5       # number of stratified CV folds
In [3]:
import numpy as np
np.random.seed(random_state)  # seed the global RNG for reproducible sub-sampling

import pandas as pd
pd.set_option('display.max_columns', 500)

# matplotlib.pylab is a deprecated compatibility shim; use the pyplot
# interface directly. Every pl.* call in this notebook (figure, imshow,
# xticks, yticks, gca, show) exists in pyplot.
import matplotlib.pyplot as pl

from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score

from sklearn.utils.class_weight import compute_class_weight

import lightgbm
from lightgbm import LGBMClassifier

from sklearn.model_selection import train_test_split
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
from sklearn.model_selection import StratifiedKFold

import shap

import pickle
from joblib import dump, load

Prepare Dataset

In [4]:
# Load the survey CSV for the selected wave
csv_path = f"data/women_work_data_{year}.csv"
dataset = pd.read_csv(csv_path)
print("Loaded dataset: ", dataset.shape)
dataset.head()
Loaded dataset:  (116417, 25)
Out[4]:
case_id_str country_code cluster_no hh_no line_no state hh_religion caste occupation no_occ white_collar blue_collar has_occ women_weight urban women_anemic total_children obese_female age years_edu wealth_index hh_members no_children_below5 freq_tv year
0 28 1 1 2 IA5 28001 1 2 [ap] andhra pradesh hindu scheduled tribe 0 1 0 0 0 2.474583 0 NaN 4 NaN 47 0 2 6 2 0.0 2005
1 28 1 1 4 IA5 28001 1 4 [ap] andhra pradesh hindu scheduled tribe 5 0 0 1 1 2.474583 0 0.0 2 0.0 23 2 2 6 2 3.0 2005
2 28 1 2 2 IA5 28001 2 2 [ap] andhra pradesh hindu other backward class 5 0 0 1 1 2.474583 0 0.0 1 0.0 20 0 2 5 0 3.0 2005
3 28 1 2 5 IA5 28001 2 5 [ap] andhra pradesh hindu other backward class 5 0 0 1 1 2.474583 0 1.0 0 0.0 40 0 2 5 0 2.0 2005
4 28 1 4 3 IA5 28001 4 3 [ap] andhra pradesh hindu other backward class 8 0 0 1 1 2.474583 0 0.0 2 0.0 31 7 3 5 0 3.0 2005
In [5]:
# Inspect the class balance of the target (including NaNs)
target_counts = dataset[target_col].value_counts(dropna=False)
print("Target column distribution:\n", target_counts)
Target column distribution:
 0    106959
1      9458
Name: white_collar, dtype: int64
In [6]:
# Drop samples where the target OR the sampling weight is missing
required_cols = [target_col, sample_weight_col]
dataset.dropna(axis=0, subset=required_cols, inplace=True)
print("Drop missing targets: ", dataset.shape)
Drop missing targets:  (116417, 25)
In [7]:
# Restrict the analysis to women aged 21 and above
adult_mask = dataset['age'] >= 21
dataset = dataset[adult_mask]
print("Drop under-21 samples: ", dataset.shape)
Drop under-21 samples:  (89163, 25)
In [8]:
# Re-check the class balance after the filtering steps above
counts_after_filtering = dataset[target_col].value_counts(dropna=False)
print("Target column distribution:\n", counts_after_filtering)
Target column distribution:
 0    80649
1     8514
Name: white_collar, dtype: int64
In [9]:
# Post-processing

# Group SC/ST castes together.
# BUG FIX: the original used chained indexing (dataset['caste'][mask] = ...),
# which raises SettingWithCopyWarning (visible in this cell's output) and may
# silently fail to write through to the underlying frame. Use .loc instead.
dataset.loc[dataset['caste'] == 'scheduled caste', 'caste'] = 'sc/st'
dataset.loc[dataset['caste'] == 'scheduled tribe', 'caste'] = 'sc/st'
if year == "200506":
    # The 2005-06 wave encodes "don't know" as the literal string '9'
    dataset.loc[dataset['caste'] == '9', 'caste'] = "don't know"

# Fix naming for General caste
dataset.loc[dataset['caste'] == 'none of above', 'caste'] = 'general'

if year == "201516":
    # Convert wealth index from str labels to ordinal int values
    wi_dict = {'poorest': 0, 'poorer': 1, 'middle': 2, 'richer': 3, 'richest': 4}
    dataset['wealth_index'] = dataset['wealth_index'].map(wi_dict)
    # .map yields NaN for unmapped labels; fail loudly like the original
    # list-comprehension (which raised KeyError) would have.
    assert dataset['wealth_index'].notna().all(), "unexpected wealth_index label"
/home/chaitanya/miniconda3/envs/tf_gpu/lib/python3.6/site-packages/ipykernel_launcher.py:5: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  """
/home/chaitanya/miniconda3/envs/tf_gpu/lib/python3.6/site-packages/ipykernel_launcher.py:7: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  import sys
/home/chaitanya/miniconda3/envs/tf_gpu/lib/python3.6/site-packages/ipykernel_launcher.py:10: SettingWithCopyWarning: 
A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy
  # Remove the CWD from sys.path while we load stuff.
In [10]:
# Define the three groups of feature columns used for modeling
x_cols_categorical = ['state', 'hh_religion', 'caste']
x_cols_binary = ['urban', 'women_anemic', 'obese_female']
x_cols_numeric = ['age', 'years_edu', 'wealth_index', 'hh_members', 'no_children_below5', 'total_children', 'freq_tv']
# Full feature list, in (categorical, binary, numeric) order
x_cols = [*x_cols_categorical, *x_cols_binary, *x_cols_numeric]
print("Feature columns:\n", x_cols)
Feature columns:
 ['state', 'hh_religion', 'caste', 'urban', 'women_anemic', 'obese_female', 'age', 'years_edu', 'wealth_index', 'hh_members', 'no_children_below5', 'total_children', 'freq_tv']
In [11]:
# Keep only rows with a complete set of feature values
dataset = dataset.dropna(axis=0, subset=x_cols)
print("Drop missing feature value rows: ", dataset.shape)
Drop missing feature value rows:  (77762, 25)
In [12]:
# Separate target column
targets = dataset[target_col]
# Separate sampling weight column
sample_weights = dataset[sample_weight_col]
# Keep only feature columns. Note: the original passed both `columns=` and
# `axis=1` to drop(); `columns=` already implies the axis, so `axis` was
# redundant and is removed here.
dataset.drop(columns=[col for col in dataset.columns if col not in x_cols], inplace=True)
print("Drop extra columns: ", dataset.shape)
Drop extra columns:  (77762, 13)
In [13]:
# Replace the 'caste' column with one-hot caste_* indicator columns
dataset = pd.get_dummies(dataset, columns=['caste'])
# 'caste' no longer exists as a single column, so drop it from the
# categorical list used by the integer-encoding step below
x_cols_categorical.remove('caste')
print("Caste to one-hot: ", dataset.shape)
Caste to one-hot:  (77762, 16)
In [14]:
# Keep an un-encoded copy of the features: SHAP plots below pass it as
# display_features so categorical values stay human-readable (state names,
# religion labels) while the model sees integer codes.
dataset_display = dataset.copy()
print("Create copy for visualization: ", dataset_display.shape)
dataset_display.head()
Create copy for visualization:  (77762, 16)
Out[14]:
state hh_religion urban women_anemic total_children obese_female age years_edu wealth_index hh_members no_children_below5 freq_tv caste_don't know caste_general caste_other backward class caste_sc/st
1 [ap] andhra pradesh hindu 0 0.0 2 0.0 23 2 2 6 2 3.0 0 0 0 1
3 [ap] andhra pradesh hindu 0 1.0 0 0.0 40 0 2 5 0 2.0 0 0 1 0
4 [ap] andhra pradesh hindu 0 0.0 2 0.0 31 7 3 5 0 3.0 0 0 1 0
5 [ap] andhra pradesh muslim 0 0.0 2 0.0 22 0 3 7 2 3.0 0 0 1 0
6 [ap] andhra pradesh hindu 0 0.0 2 1.0 25 6 3 3 0 3.0 0 1 0 0
In [15]:
# Replace each remaining categorical feature with its integer codes
for categorical_col in x_cols_categorical:
    codes, _uniques = pd.factorize(dataset[categorical_col])
    dataset[categorical_col] = codes
print("Categoricals to int encodings: ", dataset.shape)
Categoricals to int encodings:  (77762, 16)
In [16]:
dataset.head()
Out[16]:
state hh_religion urban women_anemic total_children obese_female age years_edu wealth_index hh_members no_children_below5 freq_tv caste_don't know caste_general caste_other backward class caste_sc/st
1 0 0 0 0.0 2 0.0 23 2 2 6 2 3.0 0 0 0 1
3 0 0 0 1.0 0 0.0 40 0 2 5 0 2.0 0 0 1 0
4 0 0 0 0.0 2 0.0 31 7 3 5 0 3.0 0 0 1 0
5 0 1 0 0.0 2 0.0 22 0 3 7 2 3.0 0 0 1 0
6 0 0 0 0.0 2 1.0 25 6 3 3 0 3.0 0 1 0 0
In [17]:
# Create Training and Test sets, stratified on the target so the 5% test
# split keeps the same class balance. Weights are split alongside X and Y.
# (A validation split could be carved out of the training set here if needed.)
split = train_test_split(dataset, targets, sample_weights, test_size=0.05, random_state=random_state, stratify=targets)
X_train, X_test, Y_train, Y_test, W_train, W_test = split
print("Training set: ", X_train.shape, Y_train.shape, W_train.shape)
print("Test set: ", X_test.shape, Y_test.shape, W_test.shape)
# Balanced class weights for reference (classifier uses is_unbalance instead)
train_cw = compute_class_weight("balanced", classes=np.unique(Y_train), y=Y_train)
print("Class weights: ", train_cw)
Training set:  (73873, 16) (73873,) (73873,)
Test set:  (3889, 16) (3889,) (3889,)
Class weights:  [0.54872759 5.63056402]

Build LightGBM Classifier

In [18]:
# # Define LightGBM Classifier
# model = LGBMClassifier(boosting_type='gbdt', 
#                        feature_fraction=0.8,  
#                        learning_rate=0.01,
#                        max_bins=64,
#                        max_depth=-1,
#                        min_child_weight=0.001,
#                        min_data_in_leaf=50,
#                        min_split_gain=0.0,
#                        num_iterations=1000,
#                        num_leaves=64,
#                        reg_alpha=0,
#                        reg_lambda=1,
#                        subsample_for_bin=200000,
#                        is_unbalance=True,
#                        random_state=random_state, 
#                        n_jobs=n_jobs_clf, 
#                        silent=True, 
#                        importance_type='split')
In [19]:
# # Fit model on training set
# model.fit(X_train, Y_train, sample_weight=W_train.values, 
#           #categorical_feature=x_cols_categorical,
#           categorical_feature=[])
In [20]:
# # Make predictions on Test set
# predictions = model.predict(X_test)
# print(accuracy_score(Y_test, predictions))
# print(f1_score(Y_test, predictions))
# print(confusion_matrix(Y_test, predictions))
# print(classification_report(Y_test, predictions))
In [21]:
# # Save trained model
# dump(model, f'models/{target_col}-{year}-model.joblib')
# del model
In [22]:
# # Define hyperparameter grid
# param_grid = {
#     'num_leaves': [8, 32, 64],
#     'min_data_in_leaf': [10, 20, 50],
#     'max_depth': [-1], 
#     'learning_rate': [0.01, 0.1], 
#     'num_iterations': [1000, 3000, 5000], 
#     'subsample_for_bin': [200000],
#     'min_split_gain': [0.0], 
#     'min_child_weight': [0.001],
#     'feature_fraction': [0.8, 1.0], 
#     'reg_alpha': [0], 
#     'reg_lambda': [0, 1],
#     'max_bin': [64, 128, 255]
# }
In [23]:
# # Define LightGBM Classifier
# clf = LGBMClassifier(boosting_type='gbdt',
#                      objective='binary', 
#                      is_unbalance=True,
#                      random_state=random_state,
#                      n_jobs=n_jobs_clf, 
#                      silent=True, 
#                      importance_type='split')

# # Define K-fold cross validation splitter
# kfold = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=random_state)

# # Perform grid search
# model = GridSearchCV(clf, param_grid=param_grid, scoring='f1', n_jobs=n_jobs_cv, cv=kfold, refit=True, verbose=3)
# model.fit(X_train, Y_train, 
#           sample_weight=W_train.values, 
#           #categorical_feature=x_cols_categorical,
#           categorical_feature=[])

# print('\n All results:')
# print(model.cv_results_)
# print('\n Best estimator:')
# print(model.best_estimator_)
# print('\n Best hyperparameters:')
# print(model.best_params_)
In [24]:
# # Make predictions on Test set
# predictions = model.predict(X_test)
# print(accuracy_score(Y_test, predictions))
# print(f1_score(Y_test, predictions, average='micro'))
# print(confusion_matrix(Y_test, predictions))
# print(classification_report(Y_test, predictions))
In [25]:
# # Save trained model
# dump(model, f'models/{target_col}-{year}-gridsearch.joblib')
# del model

Load LightGBM Classifier

In [26]:
# Load the LGBM model previously trained and saved by the (now commented-out)
# training cells above. NOTE(review): joblib.load unpickles arbitrary code —
# only load model files you trust.
model = load(f'models/{target_col}-{year}-model.joblib')
# model = load(f'models/{target_col}-{year}-gridsearch.joblib').best_estimator_
In [27]:
# Sanity check: Make predictions on Test set
predictions = model.predict(X_test)
print(accuracy_score(Y_test, predictions))
print(f1_score(Y_test, predictions))
print(confusion_matrix(Y_test, predictions))
print(classification_report(Y_test, predictions))
0.8449472872203652
0.40237859266600595
[[3083  461]
 [ 142  203]]
              precision    recall  f1-score   support

           0       0.96      0.87      0.91      3544
           1       0.31      0.59      0.40       345

   micro avg       0.84      0.84      0.84      3889
   macro avg       0.63      0.73      0.66      3889
weighted avg       0.90      0.84      0.87      3889

In [28]:
# Overfitting check: Make predictions on Train set
predictions = model.predict(X_train)
print(accuracy_score(Y_train, predictions))
print(f1_score(Y_train, predictions))
print(confusion_matrix(Y_train, predictions))
print(classification_report(Y_train, predictions))
0.8855874270707836
0.575062845651081
[[59702  7611]
 [  841  5719]]
              precision    recall  f1-score   support

           0       0.99      0.89      0.93     67313
           1       0.43      0.87      0.58      6560

   micro avg       0.89      0.89      0.89     73873
   macro avg       0.71      0.88      0.75     73873
weighted avg       0.94      0.89      0.90     73873


Visualizations/Explanations

Note that these plots just explain how the LightGBM model works, not necessarily how reality works. Since the LightGBM model is trained from observational data, it is not necessarily a causal model, and so just because changing a factor makes the model's prediction go up does not always mean it will raise the actual outcome.

In [29]:
# print the JS visualization code to the notebook
shap.initjs()

What makes a measure of feature importance good or bad?

  1. Consistency: Whenever we change a model such that it relies more on a feature, then the attributed importance for that feature should not decrease.
  2. Accuracy. The sum of all the feature importances should sum up to the total importance of the model. (For example if importance is measured by the R² value then the attribution to each feature should sum to the R² of the full model)

If consistency fails to hold, then we can’t compare the attributed feature importances between any two models, because then having a higher assigned attribution doesn’t mean the model actually relies more on that feature.

If accuracy fails to hold then we don’t know how the attributions of each feature combine to represent the output of the whole model. We can’t just normalize the attributions after the method is done since this might break the consistency of the method.

Using Tree SHAP for interpreting the model

In [30]:
explainer = shap.TreeExplainer(model)
# shap_values = explainer.shap_values(dataset)  # slow; precomputed values loaded instead
# BUG FIX: pickle.load(open(...)) left the file handle unclosed; use a
# context manager. NOTE: only unpickle files you trust — pickle can execute
# arbitrary code on load.
with open(f'res/{target_col}-{year}-shapvals.obj', 'rb') as f:
    shap_values = pickle.load(f)
In [31]:
# Visualize a single prediction
shap.force_plot(explainer.expected_value, shap_values[0,:], dataset_display.iloc[0,:])
Out[31]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.

The above explanation shows features each contributing to push the model output from the base value (the average model output over the training dataset we passed) to the model output. Features pushing the prediction higher are shown in red, those pushing the prediction lower are in blue.

If we take many explanations such as the one shown above, rotate them 90 degrees, and then stack them horizontally, we can see explanations for an entire dataset (in the notebook this plot is interactive):

In [32]:
# Visualize many predictions at once (stacked, rotated force plots).
# FIX: np.random.choice defaults to replace=True, so the original subsample
# could contain duplicate rows; sample without replacement instead.
subsample = np.random.choice(len(dataset), 1000, replace=False)
shap.force_plot(explainer.expected_value, shap_values[subsample,:], dataset_display.iloc[subsample,:])
Out[32]:
Visualization omitted, Javascript library not loaded!
Have you run `initjs()` in this notebook? If this notebook was from another user you must also trust this notebook (File -> Trust notebook). If you are viewing this notebook on github the Javascript has been stripped for security. If you are using JupyterLab this error is because a JupyterLab extension has not yet been written.

Summary Plots

In [33]:
# Print global feature importance: mean |SHAP value| per feature
mean_abs_shap = np.abs(shap_values).mean(axis=0)
for col, sv in zip(dataset.columns, mean_abs_shap):
    print(f"{col} - {sv}")
state - 0.45335921099130844
hh_religion - 0.1690947634414559
urban - 0.3215705789618594
women_anemic - 0.08293760494682353
total_children - 0.21364030323961417
obese_female - 0.03363755668378388
age - 0.3740933611101402
years_edu - 0.8715166560941714
wealth_index - 0.31218552873878075
hh_members - 0.3039595402954194
no_children_below5 - 0.14037130551222188
freq_tv - 0.12573968801658197
caste_don't know - 0.010316911480992905
caste_general - 0.11941058433005432
caste_other backward class - 0.07916777217078158
caste_sc/st - 0.04203507192069545
In [34]:
shap.summary_plot(shap_values, dataset, plot_type="bar")

The above figure shows the global mean(|Tree SHAP|) method applied to our model.

The x-axis is essentially the average magnitude change in model output when a feature is “hidden” from the model (for this model the output has log-odds units). “Hidden” means integrating the variable out of the model. Since the impact of hiding a feature changes depending on what other features are also hidden, Shapley values are used to enforce consistency and accuracy.

However, since we now have individualized explanations for every person in our dataset, to get an overview of which features are most important for a model we can plot the SHAP values of every feature for every sample. The plot below sorts features by the sum of SHAP value magnitudes over all samples, and uses SHAP values to show the distribution of the impacts each feature has on the model output. The color represents the feature value (red high, blue low):

In [35]:
shap.summary_plot(shap_values, dataset_display)
  • Every person has one dot on each row.
  • The x position of the dot is the impact of that feature on the model’s prediction for the person.
  • The color of the dot represents the value of that feature for the person. Categorical variables are colored grey.
  • Dots that don’t fit on the row pile up to show density (since our dataset is large).
  • Since the LightGBM model has a logistic loss the x-axis has units of log-odds (Tree SHAP explains the change in the margin output of the model).

How to use this: We can perform analyses similar to the blog post when interpreting our models.


SHAP Dependence Plots

Next, to understand how a single feature affects the output of the model we can plot the SHAP value of that feature vs. the value of the feature for all the examples in a dataset. SHAP dependence plots show the effect of a single feature across the whole dataset. They plot a feature's value vs. the SHAP value of that feature across many samples.

SHAP dependence plots are similar to partial dependence plots, but account for the interaction effects present in the features, and are only defined in regions of the input space supported by data. The vertical dispersion of SHAP values at a single feature value is driven by interaction effects, and another feature is chosen for coloring to highlight possible interactions. One of the benefits of SHAP dependence plots over traditional partial dependence plots is this ability to distinguish between models with and without interaction terms. In other words, SHAP dependence plots give an idea of the magnitude of the interaction terms through the vertical variance of the scatter plot at a given feature value.

Good example of using Dependency Plots: https://slundberg.github.io/shap/notebooks/League%20of%20Legends%20Win%20Prediction%20with%20XGBoost.html

Plots for 'age'

In [36]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('age', 'age'),
         ('age', 'urban'),
         ('age', 'caste_sc/st'),
         ('age', 'caste_general'),
         ('age', 'wealth_index'),
         ('age', 'years_edu'),
         ('age', 'no_children_below5'),
         ('age', 'total_children'),
         ('hh_religion', 'age'),
         ('state', 'age')]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(col_name, shap_values, dataset, display_features=dataset_display, interaction_index=int_col_name)
Feature: age, Interaction Feature: age
Feature: age, Interaction Feature: urban
Feature: age, Interaction Feature: caste_sc/st
Feature: age, Interaction Feature: caste_general
Feature: age, Interaction Feature: wealth_index
Feature: age, Interaction Feature: years_edu
Feature: age, Interaction Feature: no_children_below5
Feature: age, Interaction Feature: total_children
Feature: hh_religion, Interaction Feature: age
Feature: state, Interaction Feature: age

Plots for 'wealth_index'

In [37]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('wealth_index', 'wealth_index'),
         ('wealth_index', 'age'), 
         ('wealth_index', 'urban'),
         ('wealth_index', 'caste_sc/st'),
         ('wealth_index', 'caste_general'),
         ('wealth_index', 'years_edu'),
         ('wealth_index', 'no_children_below5'),
         ('wealth_index', 'total_children'),
         ('hh_religion', 'wealth_index'),
         ('state', 'wealth_index')
        ]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(col_name, shap_values, dataset, display_features=dataset_display, interaction_index=int_col_name)
Feature: wealth_index, Interaction Feature: wealth_index
Feature: wealth_index, Interaction Feature: age
Feature: wealth_index, Interaction Feature: urban
Feature: wealth_index, Interaction Feature: caste_sc/st
Feature: wealth_index, Interaction Feature: caste_general
Feature: wealth_index, Interaction Feature: years_edu
Feature: wealth_index, Interaction Feature: no_children_below5
Feature: wealth_index, Interaction Feature: total_children
Feature: hh_religion, Interaction Feature: wealth_index
Feature: state, Interaction Feature: wealth_index

Plots for 'years_edu'

In [38]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('years_edu', 'years_edu'),
         ('years_edu', 'age'), 
         ('years_edu', 'urban'),
         ('years_edu', 'caste_sc/st'),
         ('years_edu', 'caste_general'),
         ('years_edu', 'wealth_index'),
         ('years_edu', 'no_children_below5'),
         ('years_edu', 'total_children'),
         ('hh_religion', 'years_edu'),
         ('state', 'years_edu')
        ]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(col_name, shap_values, dataset, display_features=dataset_display, interaction_index=int_col_name)
Feature: years_edu, Interaction Feature: years_edu
Feature: years_edu, Interaction Feature: age
Feature: years_edu, Interaction Feature: urban
Feature: years_edu, Interaction Feature: caste_sc/st
Feature: years_edu, Interaction Feature: caste_general
Feature: years_edu, Interaction Feature: wealth_index
Feature: years_edu, Interaction Feature: no_children_below5
Feature: years_edu, Interaction Feature: total_children
Feature: hh_religion, Interaction Feature: years_edu
Feature: state, Interaction Feature: years_edu

Plots for 'caste_sc/st'

In [39]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('caste_sc/st', 'caste_sc/st'),
         ('caste_sc/st', 'age'), 
         ('caste_sc/st', 'urban'),
         ('caste_sc/st', 'years_edu'),
         ('caste_sc/st', 'wealth_index'),
         ('caste_sc/st', 'no_children_below5'),
         ('caste_sc/st', 'total_children'),
         ('hh_religion', 'caste_sc/st'),
         ('state', 'caste_sc/st')
        ]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(col_name, shap_values, dataset, display_features=dataset_display, interaction_index=int_col_name)
Feature: caste_sc/st, Interaction Feature: caste_sc/st
Feature: caste_sc/st, Interaction Feature: age
Feature: caste_sc/st, Interaction Feature: urban
Feature: caste_sc/st, Interaction Feature: years_edu
Feature: caste_sc/st, Interaction Feature: wealth_index
Feature: caste_sc/st, Interaction Feature: no_children_below5
Feature: caste_sc/st, Interaction Feature: total_children
Feature: hh_religion, Interaction Feature: caste_sc/st
Feature: state, Interaction Feature: caste_sc/st

Plots for 'caste_general'

In [40]:
# Define pairs of features and interaction indices for dependence plots
pairs = [('caste_general', 'caste_general'),
         ('caste_general', 'age'), 
         ('caste_general', 'urban'),
         ('caste_general', 'years_edu'),
         ('caste_general', 'wealth_index'),
         ('caste_general', 'no_children_below5'),
         ('caste_general', 'total_children'),
         ('hh_religion', 'caste_general'),
         ('state', 'caste_general')
        ]

# Dependence plots between pairs
for col_name, int_col_name in pairs:
    print(f"\nFeature: {col_name}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(col_name, shap_values, dataset, display_features=dataset_display, interaction_index=int_col_name)
Feature: caste_general, Interaction Feature: caste_general
Feature: caste_general, Interaction Feature: age
Feature: caste_general, Interaction Feature: urban
Feature: caste_general, Interaction Feature: years_edu
Feature: caste_general, Interaction Feature: wealth_index
Feature: caste_general, Interaction Feature: no_children_below5
Feature: caste_general, Interaction Feature: total_children
Feature: hh_religion, Interaction Feature: caste_general
Feature: state, Interaction Feature: caste_general

Visualizing Bar/Summary plots split by age bins

In [41]:
bins = [(21,25), (26,30), (31,35), (36,40), (41,45), (46,50)]

for low, high in bins:
    # Sample dataset by age range
    dataset_sample = dataset[(dataset.age > low) & (dataset.age <= high)]
    dataset_display_sample = dataset_display[(dataset.age > low) & (dataset.age <= high)]
    targets_sample = targets[(dataset.age > low) & (dataset.age <= high)]
    shap_values_sample = shap_values[(dataset.age > low) & (dataset.age <= high)]
    
    print("\nAge Range: {} - {} years".format(low, high))
    print("Sample size: {}\n".format(len(dataset_sample)))
    
    for col, sv in zip(dataset_sample.columns, np.abs(shap_values_sample).mean(0)):
        print(f"{col} - {sv}")
    
    # Summary plots
    shap.summary_plot(shap_values_sample, dataset_sample, plot_type="bar")
    shap.summary_plot(shap_values_sample, dataset_display_sample)
Age Range: 21 - 25 years
Sample size: 13487

state - 0.46992020738034157
hh_religion - 0.18140922934856257
urban - 0.3336572786421281
women_anemic - 0.08100022045221847
total_children - 0.23465669694469798
obese_female - 0.027298998553193628
age - 0.46403324425528997
years_edu - 0.9293177879556868
wealth_index - 0.2648200500809891
hh_members - 0.31155558901911545
no_children_below5 - 0.19562385934566967
freq_tv - 0.11949153975737899
caste_don't know - 0.00955178346883499
caste_general - 0.07850655661546928
caste_other backward class - 0.11194102036022203
caste_sc/st - 0.053765037171127114
Age Range: 26 - 30 years
Sample size: 13101

state - 0.4398593836668961
hh_religion - 0.16093178236486166
urban - 0.28281029493957893
women_anemic - 0.07881923100376449
total_children - 0.16492927927455714
obese_female - 0.029179571239632825
age - 0.1965807902429926
years_edu - 0.9114362861441756
wealth_index - 0.31352332209703065
hh_members - 0.3099084515103895
no_children_below5 - 0.1574901978994082
freq_tv - 0.12642958828987888
caste_don't know - 0.011928131442362228
caste_general - 0.10515081853674038
caste_other backward class - 0.08465401106194358
caste_sc/st - 0.043494719959466986
Age Range: 31 - 35 years
Sample size: 12209

state - 0.4466064518955531
hh_religion - 0.16833700795852488
urban - 0.28690186512651755
women_anemic - 0.08958444109859162
total_children - 0.15893571072433432
obese_female - 0.034972296299621904
age - 0.2426677743332041
years_edu - 0.8294093971917675
wealth_index - 0.35028010339678173
hh_members - 0.33128707213685504
no_children_below5 - 0.12175428309781369
freq_tv - 0.13497746953110404
caste_don't know - 0.010463038364647202
caste_general - 0.1287518831284135
caste_other backward class - 0.06437607704153163
caste_sc/st - 0.038056741568570036
Age Range: 36 - 40 years
Sample size: 10658

state - 0.45032308097086843
hh_religion - 0.1758026740158495
urban - 0.3315426257998196
women_anemic - 0.08231044946467116
total_children - 0.20624755786662563
obese_female - 0.036196057971379945
age - 0.4581853088384662
years_edu - 0.8213589195791061
wealth_index - 0.3278563762845138
hh_members - 0.30625942675841855
no_children_below5 - 0.09713633451602675
freq_tv - 0.12516546266234513
caste_don't know - 0.009782415863665808
caste_general - 0.1392578259238945
caste_other backward class - 0.05973221171193017
caste_sc/st - 0.03616842791722035
Age Range: 41 - 45 years
Sample size: 8613

state - 0.4413079684938689
hh_religion - 0.17142473518064053
urban - 0.3676142404738651
women_anemic - 0.08361906229383329
total_children - 0.23802327849981658
obese_female - 0.04178861167900035
age - 0.24577466275991716
years_edu - 0.8471995528162141
wealth_index - 0.31874165875470584
hh_members - 0.2710189471583502
no_children_below5 - 0.10612826947891515
freq_tv - 0.12014911926494093
caste_don't know - 0.009687906407648618
caste_general - 0.14174825451845444
caste_other backward class - 0.06927714497712052
caste_sc/st - 0.03534778909028783
Age Range: 46 - 50 years
Sample size: 4290

state - 0.4829492861558039
hh_religion - 0.16589825736226887
urban - 0.3412939726916721
women_anemic - 0.08722171557794567
total_children - 0.212851985804436
obese_female - 0.0448800516298512
age - 0.27564373178336365
years_edu - 0.8495297920013909
wealth_index - 0.29507633249839177
hh_members - 0.26089706388378575
no_children_below5 - 0.1324396389457257
freq_tv - 0.1536110694669397
caste_don't know - 0.010734939679287412
caste_general - 0.1794589788944331
caste_other backward class - 0.05264753862108067
caste_sc/st - 0.034467692303172565

SHAP Interaction Values

SHAP interaction values are a generalization of SHAP values to higher order interactions.

The model returns a matrix for every prediction, where the main effects are on the diagonal and the interaction effects are off-diagonal. The main effects are similar to the SHAP values you would get for a linear model, and the interaction effects capture all the higher-order interactions and divide them up among the pairwise interaction terms.

Note that the sum of the entire interaction matrix is the difference between the model's current output and expected output, and so the interaction effects on the off-diagonal are split in half (since there are two of each). When plotting interaction effects the SHAP package automatically multiplies the off-diagonal values by two to get the full interaction effect.

In [42]:
# Sample from dataset based on sample weights
dataset_ss = dataset.sample(10000, weights=sample_weights, random_state=random_state)
print(dataset_ss.shape)
dataset_display_ss = dataset_display.loc[dataset_ss.index]
print(dataset_display_ss.shape)
(10000, 16)
(10000, 16)
In [43]:
# Compute SHAP interaction values (time consuming)
# shap_interaction_values = explainer.shap_interaction_values(dataset_ss)
shap_interaction_values = pickle.load(open(f'res/{target_col}-{year}-shapints.obj', 'rb'))
In [44]:
shap.summary_plot(shap_interaction_values, dataset_display_ss, max_display=15)

Heatmap of SHAP Interaction Values

In [45]:
# Total |interaction| strength per feature pair, summed over samples
interaction_totals = np.abs(shap_interaction_values).sum(axis=0)
# Zero the diagonal so main effects don't dominate the heatmap
np.fill_diagonal(interaction_totals, 0)
# Order features by total off-diagonal interaction strength (top 50)
inds = np.argsort(-interaction_totals.sum(axis=0))[:50]
ordered = interaction_totals[np.ix_(inds, inds)]
pl.figure(figsize=(12, 12))
pl.imshow(ordered)
pl.yticks(range(ordered.shape[0]), dataset_ss.columns[inds], rotation=50.4, horizontalalignment="right")
pl.xticks(range(ordered.shape[0]), dataset_ss.columns[inds], rotation=50.4, horizontalalignment="left")
pl.gca().xaxis.tick_top()
pl.show()

SHAP Interaction Value Dependence Plots

Running a dependence plot on the SHAP interaction values allows us to separately observe the main effects and the interaction effects.

Below we plot the main effects for age and some of the interaction effects for age. It is informative to compare the main effects plot of age with the earlier SHAP value plot for age. The main effects plot has no vertical dispersion because the interaction effects are all captured in the off-diagonal terms.

Good example of how to infer interesting stuff from interaction values: https://slundberg.github.io/shap/notebooks/NHANES%20I%20Survival%20Model.html

In [46]:
# The ("age", "age") tuple selects the diagonal entry of the interaction
# matrix, i.e. the main effect of age with all interactions removed.
shap.dependence_plot(
    ("age", "age"), 
    shap_interaction_values, dataset_ss, display_features=dataset_display_ss
)

Now we plot the interaction effects involving age (and other features after that). These effects capture all of the vertical dispersion that was present in the original SHAP plot but is missing from the main effects plot above.

Plots for 'age'

In [47]:
# Main effect of 'age' (first entry) followed by its interaction effects
# with a hand-picked set of partner features.
main_feature = 'age'
partner_features = ['age', 'urban', 'caste_sc/st', 'caste_general',
                    'wealth_index', 'years_edu', 'no_children_below5',
                    'total_children']

# One dependence plot per (main feature, partner) pair
for int_col_name in partner_features:
    print(f"\nFeature: {main_feature}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(
        (main_feature, int_col_name), 
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
Feature: age, Interaction Feature: age
Feature: age, Interaction Feature: urban
Feature: age, Interaction Feature: caste_sc/st
Feature: age, Interaction Feature: caste_general
Feature: age, Interaction Feature: wealth_index
Feature: age, Interaction Feature: years_edu
Feature: age, Interaction Feature: no_children_below5
Feature: age, Interaction Feature: total_children

Plots for 'wealth_index'

In [48]:
# Main effect of 'wealth_index' (first entry) followed by its interaction
# effects with a hand-picked set of partner features.
main_feature = 'wealth_index'
partner_features = ['wealth_index', 'age', 'urban', 'caste_sc/st',
                    'caste_general', 'years_edu', 'no_children_below5',
                    'total_children']

# One dependence plot per (main feature, partner) pair
for int_col_name in partner_features:
    print(f"\nFeature: {main_feature}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(
        (main_feature, int_col_name), 
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
Feature: wealth_index, Interaction Feature: wealth_index
Feature: wealth_index, Interaction Feature: age
Feature: wealth_index, Interaction Feature: urban
Feature: wealth_index, Interaction Feature: caste_sc/st
Feature: wealth_index, Interaction Feature: caste_general
Feature: wealth_index, Interaction Feature: years_edu
Feature: wealth_index, Interaction Feature: no_children_below5
Feature: wealth_index, Interaction Feature: total_children

Plots for 'years_edu'

In [49]:
# Main effect of 'years_edu' (first entry) followed by its interaction
# effects with a hand-picked set of partner features.
main_feature = 'years_edu'
partner_features = ['years_edu', 'age', 'urban', 'caste_sc/st',
                    'caste_general', 'wealth_index', 'no_children_below5',
                    'total_children']

# One dependence plot per (main feature, partner) pair
for int_col_name in partner_features:
    print(f"\nFeature: {main_feature}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(
        (main_feature, int_col_name), 
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
Feature: years_edu, Interaction Feature: years_edu
Feature: years_edu, Interaction Feature: age
Feature: years_edu, Interaction Feature: urban
Feature: years_edu, Interaction Feature: caste_sc/st
Feature: years_edu, Interaction Feature: caste_general
Feature: years_edu, Interaction Feature: wealth_index
Feature: years_edu, Interaction Feature: no_children_below5
Feature: years_edu, Interaction Feature: total_children

Plots for 'caste_sc/st'

In [50]:
# Main effect of 'caste_sc/st' (first entry) followed by its interaction
# effects with a hand-picked set of partner features.
main_feature = 'caste_sc/st'
partner_features = ['caste_sc/st', 'age', 'urban', 'years_edu',
                    'wealth_index', 'no_children_below5', 'total_children']

# One dependence plot per (main feature, partner) pair
for int_col_name in partner_features:
    print(f"\nFeature: {main_feature}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(
        (main_feature, int_col_name), 
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
Feature: caste_sc/st, Interaction Feature: caste_sc/st
Feature: caste_sc/st, Interaction Feature: age
Feature: caste_sc/st, Interaction Feature: urban
Feature: caste_sc/st, Interaction Feature: years_edu
Feature: caste_sc/st, Interaction Feature: wealth_index
Feature: caste_sc/st, Interaction Feature: no_children_below5
Feature: caste_sc/st, Interaction Feature: total_children

Plots for 'caste_general'

In [51]:
# Main effect of 'caste_general' (first entry) followed by its interaction
# effects with a hand-picked set of partner features.
main_feature = 'caste_general'
partner_features = ['caste_general', 'age', 'urban', 'years_edu',
                    'wealth_index', 'no_children_below5', 'total_children']

# One dependence plot per (main feature, partner) pair
for int_col_name in partner_features:
    print(f"\nFeature: {main_feature}, Interaction Feature: {int_col_name}")
    shap.dependence_plot(
        (main_feature, int_col_name), 
        shap_interaction_values, dataset_ss, display_features=dataset_display_ss
    )
Feature: caste_general, Interaction Feature: caste_general
Feature: caste_general, Interaction Feature: age
Feature: caste_general, Interaction Feature: urban
Feature: caste_general, Interaction Feature: years_edu
Feature: caste_general, Interaction Feature: wealth_index
Feature: caste_general, Interaction Feature: no_children_below5
Feature: caste_general, Interaction Feature: total_children